#!/usr/bin/env python
# coding: utf-8

# In[1]:


import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torch.nn.functional as F
import asyncio
import pandas as pd
import numpy as np
import random
import math
import os

# Set the working directory
# The data paths used later in this script are relative ('data/output/...'),
# so everything is resolved from the parent of the launch directory.
os.chdir("..")
print(os.getcwd())

# Set the random seed
# The model weights are initialised through torch's RNG, which random.seed()
# alone does not touch, so seed torch and NumPy alongside the stdlib generator.
random.seed(10012)
np.random.seed(10012)
torch.manual_seed(10012)


# In[2]:


# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)


# In[3]:


# Dataset prefixes used to build the input and output file names.
dataset_names = [
    'stwts',
    'eo',
]

# Embedding variants processed for every dataset.
# Smaller subsets that were used for partial runs:
#   ['cvec_pca16', 'cvec_nmf16']
#   ['cvec_pca16', 'cvec_nmf16', 'cvec_umap16', 'cvec_tsne16', 'lda100', 'bert']
#   ['lda100']
embed_types = [
    'cvec_pca16',
    'cvec_nmf16',
    'cvec_umap16',
    'cvec_tsne16',
    'bert',
    'roberta',
    'distil',
    'glove6B',
    'universal',
    'lda100',
]


# In[4]:


class SparseModel(nn.Module):
    """Single square linear map with no bias: ``forward(x) = x @ weights``.

    ``weights`` is an ``hdim x hdim`` trainable parameter initialised with
    Xavier-normal values.
    """

    def __init__(self, hdim):
        super().__init__()
        self.hdim = hdim
        # nn.Parameter already marks the tensor as requiring gradients.
        self.weights = nn.Parameter(torch.zeros(hdim, hdim))
        nn.init.xavier_normal_(self.weights)

    def forward(self, input):
        """Right-multiply ``input`` by the weight matrix."""
        return input @ self.weights

def custom_loss(output, input):
    """Return half the squared Frobenius norm of ``input - output``,
    divided by the size of ``input`` along dimension 1."""
    residual = input - output
    frob = torch.norm(residual, p='fro')
    n_cols = input.size()[1]
    return 0.5 * torch.square(frob) / n_cols


# In[5]:


# Optimisation settings shared by every model trained in this script.
num_epochs = 20000      # upper bound on optimiser steps per model
learning_rate = 5e-3    # step size passed to the SGD optimiser
lmbda = 1e-2            # weight of the row-norm penalty added to the loss

def Diff(li1, li2):
    """Return the elements found in exactly one of the two iterables.

    Elements only in ``li1`` come first, followed by those only in ``li2``;
    duplicates are collapsed because both inputs are converted to sets.
    """
    left = set(li1)
    right = set(li2)
    result = list(left.difference(right))
    result.extend(right.difference(left))
    return result


# In[6]:


# For every (dataset, embedding type, split) combination: fit a SparseModel
# that reconstructs the transposed training matrix from itself under a
# row-norm penalty, then save the model's row indices ranked by weight-row
# L2 norm (largest first, at most 3000 entries).
for dataset_name in dataset_names:
    testset_list = pd.read_csv('data/output/' + dataset_name + '_testset_list.csv')
    for embed_type in embed_types:
        # The full embedding table depends only on dataset and embedding type,
        # so read it once here rather than once per split.
        full_data = pd.read_csv("data/output/" + dataset_name + '_' + embed_type + '_full.csv', index_col=0)
        for i in range(testset_list.shape[1]):
            print("Generating: " + dataset_name + " " + embed_type + " iter" + str(i+1))
            idx_test = testset_list["iter" + str(i+1)]
            # NOTE(review): Diff is a symmetric difference, so any test index
            # outside range(1, len(full_data)) ends up in idx_train. Confirm
            # the index base used in the testset CSV. The row order of
            # idx_train follows set iteration order, and the indices saved
            # below are positions within that order.
            idx_train = Diff(range(1, len(full_data)), idx_test)
            data = torch.tensor(full_data.iloc[idx_train].values)
            data = data.T  # after the transpose, dim 1 indexes the selected rows
            print(data.shape)
            data = data.to(device)
            # Loop-invariant float32 view used for every forward pass.
            data_f = data.float()
            # The model must live on the same device as the data; otherwise
            # the matmul in forward() fails when running on CUDA.
            model = SparseModel(data.size()[1]).to(device)
            optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
            loss_list = []
            prev_eval_loss = 99999
            for epoch in range(num_epochs):
                model.train()
                # ===================forward=====================
                output = model(data_f)
                # Sum over rows of the per-row L2 norms of the weight matrix.
                reg21 = torch.sum(torch.norm(model.weights, p=2, dim=1))
                loss = custom_loss(output, data) + lmbda * reg21
                # ===================backward====================
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # ===================log=======================
                model.eval()
                # Evaluation only needs the value, so skip building a graph.
                with torch.no_grad():
                    eval_loss = F.l1_loss(model(data_f), data_f).item()
                loss_list.append(eval_loss)
                # Stop as soon as the L1 reconstruction error goes up.
                if eval_loss > prev_eval_loss:
                    break
                prev_eval_loss = eval_loss
            mw = model.weights.cpu().detach()
            row_norms = torch.norm(mw, p=2, dim=1)
            ranking = torch.argsort(row_norms, descending=True).numpy()[:3000]
            np.save("data/output/" + 'indices_' + dataset_name + '_' + embed_type + '_recon_iter' + str(i+1), ranking)


# In[7]:


# Disabled cell: everything below sits inside a bare triple-quoted string, so
# it is evaluated as a string expression and none of it runs. It holds a
# decorator that submits calls through run_in_executor and a function-wrapped
# copy of the training loop (recon_fun).
# NOTE(review): the file paths in this copy lack the 'data/output/' prefix
# used by the live loop earlier in the script.
'''## Parallel test

def background(f):
    def wrapped(*args, **kwargs):
        return asyncio.get_event_loop().run_in_executor(None, f, *args, **kwargs)
    return wrapped


@background
def recon_fun(q,i,j):
    dataset_name = dataset_names[q]
    testset_list = pd.read_csv(dataset_name+'_testset_list.csv')
    print("Generating: "+ dataset_names[q] + " " + embed_types[j] + " iter" + str(i+1))
    data = pd.read_csv(dataset_name + '_' + embed_types[j] + '_full.csv', index_col=0)
    idx_test = testset_list["iter"+str(i+1)]
    idx_train = Diff(range(1, len(data)), idx_test)
    data = data.iloc[idx_train]
    data = torch.tensor(data.values)
    data = data.T
    print(data.shape)
    data = data.to(device)
    #max_data = torch.max(data)
    #min_data = torch.min(data)
    #data = (data - min_data) / (max_data - min_data)
    model = SparseModel(data.size()[1])
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    loss_list = []
    prev_eval_loss = 99999
    for epoch in range(num_epochs):
      model.train()
      # ===================forward=====================
      output = model(data.float())
      reg21 = torch.sum(torch.norm(model.weights, p=2, dim=1))
      loss = custom_loss(output, data) + lmbda * reg21
      #print('L1 norm of L2 norm of weights: ', reg21.item())
      # ===================backward====================
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      # ===================log=======================
      #print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss.item()))
      model.eval()
      eval_loss = F.l1_loss(model(data.float()), data.float()).item()
      loss_list.append(eval_loss)
      #print('Eval loss: ', eval_loss)
      if eval_loss > prev_eval_loss:
        break
      prev_eval_loss = eval_loss
    mw = model.weights.cpu().detach()
    np.save('indices_'+ dataset_name +'_' + embed_types[j] + '_recon_iter' + str(i+1), torch.argsort(torch.norm(mw, p=2, dim=1), descending=True).numpy()[:3000])
    #torch.save(model.state_dict(), './sim_autoencoder.pth')

@background    
def test_fun(q,i,j):
    print('indices_'+ dataset_names[q] +'_' + embed_types[j] + '_recon_iter' + str(i+1))

'''


# In[8]:


# Disabled cell (bare string literal, never executed): would call recon_fun
# for every dataset, ten iterations and every embedding type. recon_fun is
# itself only defined inside a disabled string cell in this file.
'''for q in range(len(dataset_names)):
    for i in range(10):
        for j in range(len(embed_types)):
            recon_fun(q,i,j)

'''


# | Embedding type | learning rate | lambda |
# | --- | --- | --- |
# | cvec_tsne16 | 0.0005 | 0.1 |
# | cvec_pca16 | 0.0005 | 0.1 |
# | cvec_umap16 | 0.0005 | 0.1 |
# | cvec_nmf16 | 0.005 | 0.01 |
# | bert | 0.01 | 0.01 |
# | distil | 0.01 | 0.01 |
# | roberta | 0.01 | 0.01 |
# | glove6B | 0.01 | 0.01 |
# | universal | 0.05 | 0.001 |

# In[9]:


# Disabled cell (bare string literal, never executed): builds the full
# (q, i, j) grid with itertools.product and maps recon_fun over it with
# parmap.starmap.
# NOTE(review): the mp.Pool created in this cell is not passed to
# parmap.starmap; confirm how the pool was meant to be used before reviving it.
'''import itertools

# Get the combinations of elements (like expand.grid)
list1 = [range(len(dataset_names)), range(10), range(len(embed_types))]
combinations = [p for p in itertools.product(*list1)]
combos2 = [list(ele) for ele in combinations]
combos3 = pd.DataFrame(combos2, columns = ["q", "i", "j"])

print(combos3)

# Set up the multicore
import parmap
import multiprocessing as mp
pool = mp.Pool(mp.cpu_count()-4)

results = parmap.starmap(recon_fun, zip(combos3['q'].tolist(), combos3['i'].tolist(), combos3['j'].tolist()))

pool.close()'''

